import numpy as np
import torch
import xml.etree.ElementTree as ET
import model.config as cfg

from PIL import Image
from torch.utils.data import Dataset
from model.utils.augmentation import augment_img
from model.utils.names import CLASSES


class VOCDataset(Dataset):
    """PASCAL VOC detection dataset.

    Each line of ``txt_file`` is a path to a JPEG image; the matching
    annotation XML is found by swapping ``JPEGImages`` -> ``Annotations``
    and ``.jpg`` -> ``.xml`` in that path.  All annotation files are
    parsed eagerly in the constructor and cached.
    """

    def __init__(self, txt_file):
        with open(txt_file, "r") as f:
            lines = f.readlines()

        self.image_list = [i.rstrip("\n") for i in lines]
        self.annotations = [
            self._load_pascal_annotation(img_file) for img_file in self.image_list
        ]

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        """Return one augmented training sample.

        Returns:
            im_data_resize: float tensor (3, input_h, input_w), values in [0, 1]
            boxes:          float tensor (N, 4), [x1, y1, x2, y2] normalized
                            to (0, 1) relative to the augmented image size
            gt_classes:     tensor (N,) of class indices
            num_obj:        long tensor (1,) holding N
        """
        # Force RGB so np.array always yields an (H, W, 3) array; grayscale
        # or palette images would otherwise produce a 2-D array and break
        # the permute(2, 0, 1) below.
        im_data = Image.open(self.image_list[index]).convert("RGB")
        boxes = self.annotations[index]["boxes"]
        gt_classes = self.annotations[index]["gt_classes"]

        im_data, boxes, gt_classes = augment_img(im_data, boxes, gt_classes)

        # Normalize coordinates into (0, 1).  Work on a float copy so the
        # cached annotation array is never mutated in place — in-place
        # normalization would corrupt the boxes on the second epoch if
        # augment_img ever returns the cached array unchanged.
        boxes = np.array(boxes, dtype=np.float32)
        w, h = im_data.size[0], im_data.size[1]
        boxes[:, 0::2] = np.clip(boxes[:, 0::2] / w, 0.001, 0.999)
        boxes[:, 1::2] = np.clip(boxes[:, 1::2] / h, 0.001, 0.999)

        # Resize image to the network input resolution.
        input_h, input_w = cfg.input_size
        im_data = im_data.resize((input_w, input_h))
        im_data_resize = torch.from_numpy(np.array(im_data)).float() / 255
        im_data_resize = im_data_resize.permute(2, 0, 1)
        boxes = torch.from_numpy(boxes)
        gt_classes = torch.from_numpy(gt_classes)
        num_obj = torch.Tensor([boxes.size(0)]).long()
        return im_data_resize, boxes, gt_classes, num_obj

    def _load_pascal_annotation(self, img_path):
        """
        Load bounding boxes and class ids from the PASCAL VOC XML file
        corresponding to *img_path*.

        Returns:
            dict with
              - "boxes":      float32 ndarray (num_objs, 4), [x1, y1, x2, y2]
              - "gt_classes": int32 ndarray (num_objs,), indices into CLASSES
        """
        filename = img_path.replace("JPEGImages", "Annotations").replace(".jpg", ".xml")
        tree = ET.parse(filename)
        objs = tree.findall("object")
        num_objs = len(objs)

        # float32, not uint16: uint16 would wrap negative coordinates
        # (xmin == 0 gives x1 == -1 -> 65535) and torch.from_numpy cannot
        # convert uint16 arrays.
        boxes = np.zeros((num_objs, 4), dtype=np.float32)
        gt_classes = np.zeros((num_objs), dtype=np.int32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find("bndbox")
            # Make pixel indexes 0-based; clamp at 0 because some VOC
            # annotations already use 0-based (or zero) xmin/ymin.
            x1 = max(float(bbox.find("xmin").text) - 1, 0)
            y1 = max(float(bbox.find("ymin").text) - 1, 0)
            x2 = float(bbox.find("xmax").text) - 1
            y2 = float(bbox.find("ymax").text) - 1

            cls = obj.find("name").text.lower().strip()
            cls_id = CLASSES.index(cls)
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls_id

        return {
            "boxes": boxes,
            "gt_classes": gt_classes,
        }


def detection_collate(batch):
    """Collate samples whose box/class tensors have variable length.

    Boxes and class vectors are zero-padded out to the largest object
    count in the batch so they can be stacked into dense tensors.

    Arguments:
        batch: list of tuples (im_data, boxes, gt_classes, num_obj)
          - im_data   : (3, H, W)
          - boxes     : (N, 4)
          - gt_classes: (N)
          - num_obj   : (1)

    Returns:
        tuple of
          - (batch_size, 3, H, W)
          - (batch_size, max_N, 4)   zero-padded boxes
          - (batch_size, max_N)      zero-padded class ids
          - (batch_size, 1)          true object counts
    """
    # Transpose the list of sample tuples into per-field sequences.
    images, box_seqs, class_seqs, counts = zip(*batch)

    batch_size = len(batch)
    widest = max(int(c.item()) for c in counts)

    out_boxes = torch.zeros(batch_size, widest, 4)
    out_classes = torch.zeros(batch_size, widest)

    # Copy each sample's real objects into the front of its padded row.
    for row, (bxs, cls, cnt) in enumerate(zip(box_seqs, class_seqs, counts)):
        n = int(cnt.item())
        out_boxes[row, :n, :] = bxs
        out_classes[row, :n] = cls

    return (
        torch.stack(images, 0),
        out_boxes,
        out_classes,
        torch.stack(counts, 0),
    )
